{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MhoQ0WE77laV"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "_ckMIh7O7s6D"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYysdyb-CaWM"
      },
      "source": [
        "# Custom training with TPUs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/r1/tutorials/distribute/tpu_custom_training.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/r1/tutorials/distribute/tpu_custom_training.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2"
      },
      "source": [
        "> Note: This is an archived TF1 notebook. It is configured\n",
        "to run in TF2's\n",
        "[compatibility mode](https://www.tensorflow.org/guide/migrate)\n",
        "but will run in TF1 as well. To use TF1 in Colab, use the\n",
        "[%tensorflow_version 1.x](https://colab.research.google.com/notebooks/tensorflow_version.ipynb)\n",
        "magic."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbVhjPpzn6BM"
      },
      "source": [
        "This tutorial walks you through using [tf.distribute.experimental.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy), one of the strategies in the `tf.distribute.Strategy` API, which makes it easy to switch a model to running on TPUs. You will create a Keras model and train it with a custom training loop instead of calling the `fit` method.\n",
        "\n",
        "After completing it, you should understand what a strategy is and why TensorFlow needs one. Once you understand the strategy framework, switching between CPU, GPU, and other device configurations becomes easier. To keep the introduction simple, the Keras model is a small convolutional neural network. A Keras model is usually trained in one line of code (by calling its `fit` method), but because some users require additional customization, this tutorial shows how to use a custom training loop. Distribution Strategy was originally written by DeepMind; you can [read the story here](https://deepmind.com/blog/tf-replicator-distributed-machine-learning/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzLKpmZICaWN"
      },
      "outputs": [],
      "source": [
        "# Import TensorFlow\n",
        "import tensorflow.compat.v1 as tf\n",
        "tf.compat.v1.disable_eager_execution()\n",
        "\n",
        "# Helper libraries\n",
        "import numpy as np\n",
        "import os\n",
        "\n",
        "# Use .get() so that a missing variable triggers the message below instead of a KeyError.\n",
        "assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'\n",
        "\n",
        "# Compare versions as integer tuples; a float comparison would treat 1.9 as newer than 1.14.\n",
        "assert tuple(int(part) for part in tf.__version__.split('.')[:2]) >= (1, 14), 'Make sure that TensorFlow version is at least 1.14'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7nSiTgSt-RrV"
      },
      "outputs": [],
      "source": [
        "TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MM6W__qraV55"
      },
      "source": [
        "## Create model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "58ff7ew6MK9d"
      },
      "source": [
        "You will be working with the [MNIST data](https://en.wikipedia.org/wiki/MNIST_database), a collection of 70,000 labeled greyscale images of handwritten digits. A convolutional neural network is a good fit for this kind of image data, and you will build one with the Keras API."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7MqDQO0KCaWS"
      },
      "outputs": [],
      "source": [
        "def create_model(input_shape):\n",
        "  \"\"\"Creates a simple convolutional neural network model using the Keras API\"\"\"\n",
        "  return tf.keras.Sequential([\n",
        "      tf.keras.layers.Conv2D(28, kernel_size=(3, 3), activation='relu', input_shape=input_shape),\n",
        "      tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(128, activation=tf.nn.relu),\n",
        "      tf.keras.layers.Dropout(0.2),\n",
        "      # Output raw logits: tf.losses.sparse_softmax_cross_entropy applies the softmax itself.\n",
        "      tf.keras.layers.Dense(10),\n",
        "  ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AXoHhrsbdF3"
      },
      "source": [
        "## Loss and gradient"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5mVuLZhbem8d"
      },
      "source": [
        "Since you are preparing to use a custom training loop, you need to explicitly write down the loss and gradient functions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F2VeZUWUj5S4"
      },
      "outputs": [],
      "source": [
        "def loss(model, x, y):\n",
        "  \"\"\"Returns the model output and the loss for a batch (x, y)\"\"\"\n",
        "  logits = model(x)\n",
        "  return logits, tf.losses.sparse_softmax_cross_entropy(labels=y, logits=logits)\n",
        "\n",
        "def grad(model, x, y):\n",
        "  \"\"\"Returns the model output, the loss, and the gradients for a batch (x, y)\"\"\"\n",
        "  logits, loss_value = loss(model, x, y)\n",
        "  return logits, loss_value, tf.gradients(loss_value, model.trainable_variables)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k53F5I_IiGyI"
      },
      "source": [
        "## Main function"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Qb6nDgxiN_n"
      },
      "source": [
        "The previous sections defined the model and the loss and gradient functions. The remaining cells put them together into a complete, runnable example of using TPUStrategy with a Keras model and a custom training loop."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jwJtsCQhHK-E"
      },
      "outputs": [],
      "source": [
        "tf.keras.backend.clear_session()\n",
        "\n",
        "resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_WORKER)\n",
        "tf.tpu.experimental.initialize_tpu_system(resolver)\n",
        "strategy = tf.distribute.experimental.TPUStrategy(resolver)\n",
        "\n",
        "# Load MNIST training and test data\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
        "\n",
        "# All MNIST examples are 28x28 pixel greyscale images (hence the 1\n",
        "# for the number of channels).\n",
        "input_shape = (28, 28, 1)\n",
        "\n",
        "# Only specific data types are supported on the TPU, so it is important to\n",
        "# pay attention to these.\n",
        "# More information:\n",
        "# https://cloud.google.com/tpu/docs/troubleshooting#unsupported_data_type\n",
        "x_train = x_train.reshape(x_train.shape[0], *input_shape).astype(np.float32)\n",
        "x_test = x_test.reshape(x_test.shape[0], *input_shape).astype(np.float32)\n",
        "y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)\n",
        "\n",
        "# The batch size must be divisible by the number of replicas (the 8 TPU\n",
        "# workers), so batch sizes of 8, 16, 24, 32, ... are supported.\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "NUM_EPOCHS = 5\n",
        "\n",
        "train_steps_per_epoch = len(x_train) // BATCH_SIZE\n",
        "test_steps_per_epoch = len(x_test) // BATCH_SIZE"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GPrDC8IfOgCT"
      },
      "source": [
        "## Start by creating objects within the strategy's scope\n",
        "\n",
        "Model creation, optimizer creation, and similar setup must happen inside `strategy.scope()` in order to use TPUStrategy.\n",
        "\n",
        "Also initialize metrics for the train and test sets. For more information, see `tf.keras.metrics.Mean` and `tf.keras.metrics.SparseCategoricalAccuracy`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s_suB7CZNw5W"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  model = create_model(input_shape)\n",
        "\n",
        "  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)\n",
        "\n",
        "  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)\n",
        "  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      'training_accuracy', dtype=tf.float32)\n",
        "  test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)\n",
        "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      'test_accuracy', dtype=tf.float32)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d3iLK5ZtO1_R"
      },
      "source": [
        "## Define custom train and test steps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XAF6xfU0N5ID"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  def train_step(inputs):\n",
        "    \"\"\"Each training step runs this custom function, which calculates\n",
        "    gradients and updates weights.\n",
        "    \"\"\"\n",
        "    x, y = inputs\n",
        "\n",
        "    logits, loss_value, grads = grad(model, x, y)\n",
        "\n",
        "    update_loss = training_loss.update_state(loss_value)\n",
        "    update_accuracy = training_accuracy.update_state(y, logits)\n",
        "\n",
        "    # Show that this is truly a custom training loop\n",
        "    # Multiply all gradients by 2. `grads` is a Python list, so scale each\n",
        "    # tensor individually; `grads * 2` would only repeat the list.\n",
        "    grads = [g * 2 for g in grads]\n",
        "\n",
        "    update_vars = optimizer.apply_gradients(\n",
        "        zip(grads, model.trainable_variables))\n",
        "\n",
        "    with tf.control_dependencies([update_vars, update_loss, update_accuracy]):\n",
        "      return tf.identity(loss_value)\n",
        "\n",
        "  def test_step(inputs):\n",
        "    \"\"\"Each test step runs this custom function, which calculates the loss\n",
        "    and updates the test metrics without changing the weights.\n",
        "    \"\"\"\n",
        "    x, y = inputs\n",
        "\n",
        "    logits, loss_value = loss(model, x, y)\n",
        "\n",
        "    update_loss = test_loss.update_state(loss_value)\n",
        "    update_accuracy = test_accuracy.update_state(y, logits)\n",
        "\n",
        "    with tf.control_dependencies([update_loss, update_accuracy]):\n",
        "      return tf.identity(loss_value)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AhrK1-yEO7Nf"
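      },
      "source": [
        "## Do the training\n",
        "To keep the full training loop readable, it calls two helper functions, `run_train()` and `run_test()`. Both helpers use names (`session`, the iterator initializers, and the distributed step ops) that are defined in the final cell, so they can only be called from inside that cell's session."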
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "or5osuheouVU"
      },
      "outputs": [],
      "source": [
        "def run_train():\n",
        "  # Train\n",
        "  session.run(train_iterator_init)\n",
        "  while True:\n",
        "    try:\n",
        "      session.run(dist_train)\n",
        "    except tf.errors.OutOfRangeError:\n",
        "      break\n",
        "  print('Train loss: {:0.4f}\\t Train accuracy: {:0.4f}%'.format(\n",
        "      session.run(training_loss_result),\n",
        "      session.run(training_accuracy_result) * 100))\n",
        "  training_loss.reset_states()\n",
        "  training_accuracy.reset_states()\n",
        "\n",
        "def run_test():\n",
        "  # Test\n",
        "  session.run(test_iterator_init)\n",
        "  while True:\n",
        "    try:\n",
        "      session.run(dist_test)\n",
        "    except tf.errors.OutOfRangeError:\n",
        "      break\n",
        "  print('Test loss: {:0.4f}\\t Test accuracy: {:0.4f}%'.format(\n",
        "      session.run(test_loss_result),\n",
        "      session.run(test_accuracy_result) * 100))\n",
        "  test_loss.reset_states()\n",
        "  test_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u5LvzAwjN95j"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  training_loss_result = training_loss.result()\n",
        "  training_accuracy_result = training_accuracy.result()\n",
        "  test_loss_result = test_loss.result()\n",
        "  test_accuracy_result = test_accuracy.result()\n",
        "  \n",
        "  config = tf.ConfigProto()\n",
        "  config.allow_soft_placement = True\n",
        "  cluster_spec = resolver.cluster_spec()\n",
        "  if cluster_spec:\n",
        "    config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())\n",
        "\n",
        "  print('Starting training...')\n",
        "\n",
        "  # Do all the computations inside a Session (as opposed to doing eager mode)\n",
        "  with tf.Session(target=resolver.master(), config=config) as session:\n",
        "    all_variables = (\n",
        "        tf.global_variables() + training_loss.variables +\n",
        "        training_accuracy.variables + test_loss.variables +\n",
        "        test_accuracy.variables)\n",
        "    \n",
        "    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE, drop_remainder=True)\n",
        "    train_iterator = strategy.make_dataset_iterator(train_dataset)\n",
        "\n",
        "    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE, drop_remainder=True)\n",
        "    test_iterator = strategy.make_dataset_iterator(test_dataset)\n",
        "    \n",
        "    train_iterator_init = train_iterator.initializer\n",
        "    test_iterator_init = test_iterator.initializer\n",
        "\n",
        "    session.run([v.initializer for v in all_variables])\n",
        "    \n",
        "    dist_train = strategy.experimental_run(train_step, train_iterator).values\n",
        "    dist_test = strategy.experimental_run(test_step, test_iterator).values\n",
        "\n",
        "    # Custom training loop\n",
        "    for epoch in range(0, NUM_EPOCHS):\n",
        "      print('Starting epoch {}'.format(epoch))\n",
        "\n",
        "      run_train()\n",
        "\n",
        "      run_test()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "name": "tpu_custom_training.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
